# Copyright 2007 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Handles reading and writing from our svn userdb.

This module is meant to be replaceable with another one that implements access
to whatever database type/format you want to use. The current implementation
works fine for us, but you might prefer MySQL, pickles, or some other access
method.
"""

import crypt
import os
import random
import shutil
import tempfile
import time
import re
from errno import ENOENT, ENOTEMPTY

# Name of a marker file placed inside a user's directory. It identifies
# hand-created users that must not be auto-deleted when syncing with an
# external user database.
SVN_ONLY_FLAG = "user_is_svn_only;not_autosynced"


class NoUserError(Exception):
  """Raised when a requested user cannot be found.

  Attributes:
    username: the name that was looked up.
    child: the underlying exception that triggered this error.
  """

  def __init__(self, username, child):
    message = 'No such user ' + username
    super(NoUserError, self).__init__(message)
    self.child = child
    self.username = username


def GeneratePassword():
  """Return a new 8-character random password for subversion.

  Characters are drawn from the operating system's entropy source via
  random.SystemRandom rather than the module-level Mersenne Twister, whose
  output is predictable and therefore unsuitable for credentials.

  Returns:
    An 8-character string.
  """
  # 'l', 'O', 'o' are removed not to confuse them with '1' and '0'
  # : " \ ' ` are also removed
  char_list = ("ABCDEFGHIJKLMNPQRSTUVWXYZabcdefghijkmnpqrstuvwxyz"
               "0123456789!@#$%^&*()-=_+[]|{};<>?,./~")
  rng = random.SystemRandom()
  return ''.join(rng.choice(char_list) for _ in range(8))

def GenerateSalt():
  """Build a random two-character salt in the style crypt() expects."""
  alphabet = ("abcdefghijklmnopqrstuvwxyz"
              "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
              "0123456789./")

  # Two independent draws, concatenated in the order they were drawn.
  first = random.choice(alphabet)
  second = random.choice(alphabet)
  return first + second

def HashPassword(password):
  """Hash a password with crypt() using a freshly generated salt.

  The result is in the htpasswd-like form used by Apache httpd's BasicAuth
  mechanism.
  """
  salt = GenerateSalt()
  return crypt.crypt(password, salt)

def WriteUserData(filepath, filename, data):
  """Write data to filepath/filename via a temp file and a rename.

  The content is first written to a temporary file created inside filepath
  and then renamed over the final name, so the destination is replaced in a
  single step rather than being rewritten in place.

  Args:
    filepath: directory that will contain the file (must already exist).
    filename: final name of the file inside filepath; also used as the
      prefix of the temporary file name.
    data: iterable of strings; they are joined with newlines and a trailing
      newline is appended.

  Raises:
    TypeError: if data contains non-string items.
    OSError / IOError: if the temporary file cannot be created, written,
      chmod'ed or renamed. The temporary file is removed in that case.
  """
  # Build the payload first so bad input fails before touching the disk.
  payload = "\n".join(data) + "\n"

  tmpfile = None
  try:
    fd, tmpfile = tempfile.mkstemp(dir=filepath, prefix=filename, text=True)
    # The context manager closes the descriptor even if the write fails.
    with os.fdopen(fd, "w") as fh:
      fh.write(payload)
    # mkstemp ignores the umask and makes a file not readable by group or
    # others, so widen the permissions explicitly.
    os.chmod(tmpfile, 0o664)
    os.rename(tmpfile, os.path.join(filepath, filename))
    # Renamed successfully: nothing left to clean up.
    tmpfile = None
  finally:
    if tmpfile is not None:
      # Best-effort removal of the leftover temp file; the original error
      # (if any) is the one worth propagating.
      try:
        os.unlink(tmpfile)
      except OSError:
        pass


class UserDB(object):
  """A class for interfacing with mod_authn_dir-style directory data.

  This object can be replaced with another one to support any kind of storage
  format (flat file, pickle, SQL, etc...)

  It contains all the methods necessary to access data in our database.

  On-disk layout: <root>/<first 3 chars of username>/<username>/ holding a
  "password" file, an optional "groups" file and an optional SVN_ONLY_FLAG
  marker file.
  """

  def __init__(self, dbinit, create_db=False):
    """Initializes the DB object with whatever DB specific data is passed.

    Args:
      dbinit: for this implementation, the top level directory of the
        database.
      create_db: when True, create the directory if it does not exist.

    Raises:
      IOError: if the directory does not exist and create_db is False.
      OSError: if the directory cannot be created.
    """

    if not os.path.exists(dbinit):
      if not create_db:
        raise IOError("userdb init failed: no such directory " + dbinit)
      # this userdb implementation just needs a top level directory as to
      # be valid
      os.makedirs(dbinit)

    self._dirpath = dbinit

    # depending on the caller, we may build and cache a list of groups
    self._usergroups = {}

  def _HashUserDir(self, username):
    """Return the hashed directory path for a given username.

    We use the first 3 letters of the username, which gives a maximum number
    of 17,576 directories, and probably around 5000 directories max in most
    configurations (we do want fewer than 32,000 dirs which is the limit for
    ext2/ext3 without extended attributes)

    This in turn, gives 100 to 500 user sub-directories per hash bucket if you
    have one million users
    This is not the best hashing function, but it's simple and allows for easy
    lookup by an admin or a shell script

    Raises:
      ValueError: if the username is empty, is '.' or '..', or contains a
        path separator. Such names would resolve to the database root or to
        a location outside of it.
    """

    if (not username or username in (".", "..")
        or "/" in username or os.sep in username):
      raise ValueError("unsafe username: %r" % (username,))
    return os.path.join(self._dirpath, username[:3], username)

  def _GetUserDir(self, username):
    """Return a hash path to a user directory, or None if not present.

    Usernames that cannot be mapped to a directory inside the database are
    reported as not present.
    """

    try:
      userdir = self._HashUserDir(username)
    except ValueError:
      return None
    if os.path.exists(userdir):
      return userdir
    return None

  def _MakeUserDir(self, username):
    """Make a new user directory if needed, and return the hashed location.

    In other words, you should call this function just like you'd call
    _HashUserDir, and it will just happen to create the hashed directory for
    you if needed.

    Raises:
      ValueError: if the username is unsafe (see _HashUserDir).
      OSError: if a directory is missing and could not be created.
    """

    userdir = self._HashUserDir(username)
    # make the parent (hash bucket) and then the user directory if needed
    for path in (os.path.dirname(userdir), userdir):
      try:
        os.mkdir(path)
        os.chmod(path, 0o2775)
      except OSError:
        # An already existing directory is expected; anything that leaves us
        # without the directory is a real failure.
        if not os.path.isdir(path):
          raise

    return userdir

  def ReadUserPassword(self, username):
    """Returns a (username, plaintext_password, crypted_password) tuple.

    Only the first non-blank, non-comment line of the password file is used.
    If the file holds no such line, (username, None, None) is returned.

    Raises:
      NoUserError: if the user has no password file (or the username cannot
        be mapped to a directory).
      IOError: for any other error opening the file.
      IndexError: if the first data line has no ':' separator.
    """
    plaintext = None
    crypted = None

    try:
      userdir = self._HashUserDir(username)
    except ValueError as e:
      raise NoUserError(username, e)

    # Turn ENOENT into NoUserError, but let other errors out.
    try:
      fp = open(os.path.join(userdir, "password"), "r")
    except IOError as e:
      if e.errno != ENOENT:
        raise
      raise NoUserError(username, e)

    try:
      for line in fp:
        # Lines read from a file keep their newline; drop it so that blank
        # lines are recognised and fields carry no trailing line ending.
        line = line.rstrip("\r\n")
        if not line or line[0] == "#":
          continue
        parts = line.split(':')
        crypted = parts[0]
        plaintext = parts[1]
        # Ignore all lines after this one (not that there should be any)
        break
    finally:
      fp.close()

    return (username, plaintext, crypted)

  def WriteUserPassword(self, username, password):
    """Create a new user entry if needed and sets/changes the password.

    Returns:
      The (username, password, hashed_password) tuple that was stored.
    """

    hashed_pwd = HashPassword(password)
    # NOTE: the line carries its own newline and WriteUserData appends one
    # more, so the file ends with an empty line. Kept as is to preserve the
    # existing on-disk format.
    pwd_line = "%s:%s:%s\n" % (hashed_pwd, password, username)
    WriteUserData(self._MakeUserDir(username), "password", [ pwd_line ])

    return (username, password, hashed_pwd)

  def ReadUserGroups(self, username):
    """Get the list of all groups for a user.

    Returns an empty list if the user does not exist, has no groups file, or
    has no groups. Blank lines and lines starting with '#' are skipped.

    Raises:
      IOError: if the groups file exists but cannot be opened.
    """

    userdir = self._GetUserDir(username)
    # We accept users with no group files and just say that they have no groups
    if userdir is None:
      return []

    try:
      f = open(os.path.join(userdir, "groups"), "r")
    except IOError as e:
      if e.errno != ENOENT:
        raise
      # Non-existent groups file is no error; just means no groups.
      return []
    try:
      lines = f.read().split('\n')
    finally:
      f.close()

    return [line for line in lines if line and line[0] != "#"]

  def WriteUserGroups(self, username, grouplist):
    """Create a new user entry if needed and sets/changes the group list.
    """

    WriteUserData(self._MakeUserDir(username), "groups", grouplist)
    # Drop any cached value so UserGroups() sees the new list.
    self._usergroups.pop(username, None)

  def DeleteUser(self, username):
    """delete user hash dir with all data, and hash bucket, if possible.

    Returns:
      None if no such user existed, True if the user was deleted.

    Raises:
      OSError: if removal fails for a reason other than the hash bucket
        still containing other users.
    """

    userdir = self._GetUserDir(username)
    if userdir is None:
      return None

    shutil.rmtree(userdir)
    self._usergroups.pop(username, None)

    # Try to delete the hash bucket, just in case it became empty
    hash_bucket = os.path.dirname(userdir)
    try:
      os.rmdir(hash_bucket)
    except OSError as e:
      if e.errno != ENOTEMPTY:
        raise

    return True

  def DoesUserExist(self, username, require_password=False):
    """simply does what it says :).

    Obviously, a user can be added or removed between the time you run
    this call, and the time you use the data you got from it. Keep that in
    mind.

    Set require_password to True to only report users that also have a
    readable password file.
    """

    if not self._GetUserDir(username):
      return False
    if not require_password:
      return True

    try:
      self.ReadUserPassword(username)
    except (NoUserError, IOError):
      # NoUserError: no password file; IOError: file could not be opened.
      return False
    return True

  def GetAllUsers(self, require_password=True, return_not_autosynced=True):
    """user list generator.

    require_password: if True, yield only users with password entries.
    return_not_autosynced: If False, do not return users with SVN_ONLY_FLAG
         (useful for not deleting users if they are autosynced from another DB)

    """

    for bucket in os.listdir(self._dirpath):
      # our hashdirs are 1 to 3 letters, skip all others
      if len(bucket) > 3:
        continue

      hashdir = os.path.join(self._dirpath, bucket)
      # skip whatever could be there, but isn't a directory (just being safe)
      if not os.path.isdir(hashdir):
        continue

      for user in os.listdir(hashdir):
        userdir = os.path.join(hashdir, user)

        # if we want to exclude non autosynced users, skip those carrying
        # the marker file
        if return_not_autosynced is False:
          if os.path.isfile(os.path.join(userdir, SVN_ONLY_FLAG)):
            continue

        if not require_password:
          yield user
        # Checking for the file directly is 3-5x faster than going through
        # ReadUserPassword.
        elif os.path.isfile(os.path.join(userdir, "password")):
          yield user

  def LockUser(self, username):
    """Lock a user by prefixing the stored hash and plaintext with '!!'.

    Does nothing if the user is already locked, so that a single UnLockUser
    call always undoes the lock.

    Raises:
      NoUserError: if the user has no password file.
    """

    (username, plaintext, crypted) = self.ReadUserPassword(username)
    if crypted is not None and crypted.startswith("!!"):
      return
    pwd_line = "!!%s:!!%s:%s\n" % (crypted, plaintext, username)
    WriteUserData(self._MakeUserDir(username), "password", [ pwd_line ])

  def UnLockUser(self, username):
    """Remove the '!!' lock prefix. Invariant if the user isn't locked.

    Whether the user is locked is decided from the stored hash only, so a
    plaintext password that itself starts with '!!' is left alone for an
    unlocked user.

    Raises:
      NoUserError: if the user has no password file.
    """

    (username, plaintext, crypted) = self.ReadUserPassword(username)
    if crypted is None or not crypted.startswith("!!"):
      return
    crypted = crypted[2:]
    if plaintext.startswith("!!"):
      plaintext = plaintext[2:]
    pwd_line = "%s:%s:%s\n" % (crypted, plaintext, username)
    WriteUserData(self._MakeUserDir(username), "password", [ pwd_line ])

  def RandomizeUserPassword(self, username):
    """Create a new random password for user 'username'.

    Returns the new (username, plain, crypted) tuple for this user.
    Note: as a side effect user 'username' is created if necessary.
    """
    return self.WriteUserPassword(username, GeneratePassword())

  def CreateSvnOnlyUser(self, username):
    """Creates a user with the 'svn only' flag and a random password.

    This special flag is used to indicate to auto DB syncing scripts that
    this user should not be removed if it's absent from the DB we're syncing
    from.

    Returns the new (username, plain, crypted) tuple for the user.
    """
    tuplet = self.WriteUserPassword(username, GeneratePassword())
    # WriteUserData joins its data with newlines, so the text must be passed
    # as a one-element list; a bare string would be written one character
    # per line.
    WriteUserData(self._HashUserDir(username), SVN_ONLY_FLAG,
                  [ "not autosynced" ])
    return tuplet

  def UserGroups(self, user):
    """Like ReadUserGroups, but cache the output."""

    if user not in self._usergroups:
      self._usergroups[user] = self.ReadUserGroups(user)

    return self._usergroups[user]

  def UserInGroup(self, user, group):
    """Return True if group is in the (cached) group list of user."""
    return group in self.UserGroups(user)
